#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@author:hengk
@contact: hengk@foxmail.com
@datetime:2019-10-30 14:41
"""
import torch.nn as nn
import torch


class CrossEntroy(nn.Module):
    """Binary cross-entropy computed from raw logits, averaged over all elements.

    (The misspelled class name is kept for backward compatibility.)
    """

    def forward(self, input, target):
        """Return the mean binary cross-entropy between ``sigmoid(input)`` and ``target``.

        Args:
            input: raw, pre-sigmoid scores (logits).
            target: targets broadcastable against ``input``; presumably values
                in [0, 1] -- confirm against callers.

        Returns:
            A 0-dim tensor: the mean of the element-wise loss over the
            broadcast shape of ``input`` and ``target``.
        """
        # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)) are
        # evaluated with logsigmoid, which stays finite for large |x|.  The
        # naive torch.log(torch.sigmoid(x)) hits log(0) = -inf once sigmoid
        # saturates in floating point, producing inf / nan losses.
        log_p = torch.nn.functional.logsigmoid(input)
        log_not_p = torch.nn.functional.logsigmoid(-input)
        per_element = -target * log_p - (1 - target) * log_not_p
        return torch.mean(per_element)
class LossFactory(object):
    """Factory that maps a short loss name to a freshly constructed loss module."""

    @staticmethod
    def create(name):
        """Return a new loss module for ``name``.

        Supported names:
            ``"ce"`` -- ``CrossEntroy`` (sigmoid followed by binary cross-entropy).

        Raises:
            ValueError: if ``name`` is not a supported loss name.
        """
        if name == "ce":
            return CrossEntroy()
        # Fail fast instead of printing and returning None: a None result only
        # surfaced later as "'NoneType' object is not callable" at the call
        # site.  The message text ("no loss function <name>") is unchanged.
        raise ValueError("没有" + str(name) + "损失函数")